{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from time import time \n",
    "\n",
    "import fastai\n",
    "from fastai.basics import Learner\n",
    "from fastai.tabular.model import TabularModel\n",
    "from fastai.tabular.data import TabularDataLoaders\n",
    "from fastai.metrics import accuracy\n",
    "from fastai.callback.progress import ProgressCallback\n",
    "\n",
    "import cudf\n",
    "\n",
    "import nvtabular as nvt\n",
    "from nvtabular.ops import Normalize, FillMissing, Categorify, get_embedding_sizes\n",
    "from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('1.6.0', '0+untagged.1.gfa8e9fb', '2.0.16')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__, cudf.__version__, fastai.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext snakeviz\n",
    "# load snakeviz if you want to run profiling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3> Dataset Gathering: Define files in the training and validation datasets. </h3>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_path = '/rapids/notebooks/jperez/Documents/ds-itr/examples/'\n",
    "INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/raid/outbrain1/preprocessed')\n",
    "OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/raid/outbrain1/output')\n",
    "BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 400000))\n",
    "data_path = INPUT_DATA_DIR\n",
    "#df_test = 'test/'\n",
    "df_valid = os.path.join(data_path, 'validation_feature_vectors_integral.csv/')\n",
    "df_train = os.path.join(data_path, 'train_feature_vectors_integral_eval_imputed.csv/')\n",
    "\n",
    "train_set = [os.path.join(df_train, x) for x in os.listdir(df_train) if x.startswith(\"part\")] \n",
    "valid_set = [os.path.join(df_valid, x) for x in os.listdir(df_valid) if x.startswith(\"part\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(196, 4)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_set), len(valid_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h4>Grab column information</h4>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = open(os.path.join(data_path, 'train_feature_vectors_integral_eval.csv.header')).read().splitlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_names = ['display_id', 'is_leak', 'doc_event_id', 'ad_id', 'doc_id', 'doc_ad_entity_id', 'doc_event_entity_id', 'doc_event_entity_id', 'doc_ad_source_id', 'doc_event_source_id', 'event_geo_location', 'ad_advertiser', 'event_country_state', 'doc_ad_publisher_id', 'doc_event_publisher_id', 'doc_ad_topic_id', 'doc_event_topic_id', 'event_country', 'doc_ad_category_id', 'doc_event_category_id', 'event_hour', 'event_platform', 'traffic_source', 'event_weekend', 'user_has_already_viewed_doc']\n",
    "cont_names =  ['pop_ad_id_conf', 'pop_document_id_conf', 'user_doc_ad_sim_categories_conf', 'user_doc_ad_sim_topics_conf', 'pop_publisher_id_conf', 'pop_advertiser_id_conf', 'pop_campaign_id_conf', 'pop_source_id_conf', 'pop_entity_id_conf', 'pop_topic_id_conf', 'pop_category_id_conf', 'pop_ad_id', 'pop_document_id', 'pop_publisher_id', 'pop_advertiser_id', 'pop_campaign_id', 'pop_source_id', 'pop_entity_id', 'pop_topic_id', 'pop_category_id', 'user_doc_ad_sim_categories', 'user_doc_ad_sim_topics', 'user_doc_ad_sim_entities', 'doc_event_doc_ad_sim_categories', 'doc_event_doc_ad_sim_topics', 'doc_event_doc_ad_sim_entities', 'user_views', 'ad_views', 'doc_views', 'doc_event_days_since_published', 'doc_event_hour', 'doc_ad_days_since_published'] #+ [i for i in ds.columns if i not in cat_names and i not in ['label']]\n",
    "cat_names = [name for name in cat_names if name in cols]\n",
    "cont_names = [name for name in cont_names if name in cols]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3>Preprocessing:</h3> <p>Select operations to perform, create the Preprocessor object, create dataset iterator object and collect the stats on the training dataset</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8 µs, sys: 21 µs, total: 29 µs\n",
      "Wall time: 36.7 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "proc = nvt.Workflow(cat_names=cat_names, cont_names=cont_names, label_name=['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc.add_cont_preprocess([FillMissing(), Normalize()])\n",
    "proc.add_cat_preprocess(Categorify())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.57 ms, sys: 0 ns, total: 2.57 ms\n",
      "Wall time: 2.59 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "trains_itrs = nvt.Dataset(train_set,names=cols, engine='csv',part_size='100MB')\n",
    "valids_itrs = nvt.Dataset(valid_set,names=cols, engine='csv',part_size='200MB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_path_train = os.path.join(OUTPUT_DATA_DIR, 'train')\n",
    "output_path_valid = os.path.join(OUTPUT_DATA_DIR, 'valid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "proc.apply(trains_itrs, apply_offline=True, record_stats=True, output_path=output_path_train, shuffle=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "proc.apply(valids_itrs, apply_offline=True, record_stats=False, output_path=output_path_valid, shuffle=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_train_set = [os.path.join(output_path_train, x) for x in os.listdir(output_path_train) if x.endswith(\"parquet\")]\n",
    "new_valid_set = [os.path.join(output_path_valid, x) for x in os.listdir(output_path_valid) if x.endswith(\"parquet\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Gather embeddings using statistics gathered in the Read phase.</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = list(get_embedding_sizes(proc).values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Create the file iterators using the FileItrDataset Class.</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "t_batch_sets = nvt.Dataset(new_train_set, engine='parquet') \n",
    "v_batch_sets = nvt.Dataset(new_valid_set, engine='parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = TorchAsyncItr(t_batch_sets, \n",
    "                              batch_size=100000, \n",
    "                              cats=cat_names, \n",
    "                              conts=cont_names, \n",
    "                              labels=[\"label\"])\n",
    "valid_dataset = TorchAsyncItr(v_batch_sets, \n",
    "                              batch_size=50000, \n",
    "                              cats=cat_names, \n",
    "                              conts=cont_names, \n",
    "                              labels=[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Use the Deep Learning Collator to create a collate function to pass to the dataloader.</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_col(batch):\n",
    "    return (batch[0], batch[1], batch[2].long())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "t_data = DLDataLoader(train_dataset, collate_fn=gen_col, pin_memory=False, num_workers=0)\n",
    "v_data = DLDataLoader(valid_dataset, collate_fn=gen_col, pin_memory=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h4>After creating the Dataloaders you can leverage fastai framework to create Machine Learning models</h4>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "databunch = TabularDataLoaders(t_data, v_data, collate_fn=gen_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "model = TabularModel(emb_szs = embeddings, n_cont=len(cont_names), out_sz=2, layers=[512,256])\n",
    "\n",
    "learn =  Learner(databunch, model, loss_func = torch.nn.CrossEntropyLoss(), metrics=[accuracy], cbs=ProgressCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='' max='1' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      0.00% [0/1 00:00<00:00]\n",
       "    </div>\n",
       "    \n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='90' class='' max='598' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      15.05% [90/598 00:18<01:44 2.4129]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    }
   ],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min numerical gradient: 2.29E-02\n",
      "Min loss divided by 10: 1.00E-02\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEGCAYAAACZ0MnKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAk90lEQVR4nO3de3xdZZ3v8c9v536/NW3SpPd7Cy0tBUFFEEcE5Q46gFd05IXojDrqGeecOWdmzpwz44wzOjqiyJwBlHkJKspYrkURLAgIBdrStAHaUmiapEmbNElzz97P+WOthN00l9UmO3vt5Pt+vfLKXpe983u603z3s9Z6nmXOOURERIKIJLsAERFJHQoNEREJTKEhIiKBKTRERCQwhYaIiASWnuwCTtasWbPcwoULk12GiEhKefHFFw8758on+jopFxoLFy5k69atyS5DRCSlmNmbk/E6OjwlIiKBKTRERCQwhYaIiASm0BARkcAUGiIiEphCQ0REAlNoiIhIYAoNEZEU8J3fvM7Trx9OdhkKDRGRsOvui/Kvj7/G1jdbkl2KQkNEJOxeO9SBc7CyojDZpSg0RETCrraxHYCVFQVJrkShISISerWNHeRkpDG/NDfZpSg0RETCrrahgxUVBUQiluxSFBoiImHmnKO2sZ1Vlck/NAUKDRGRUGvq6KW1q58VcxQaIiIyjtrGDgBWVib/yilQaIiIhFptQ3iunAKFhohIqNU2dlBRmE1xbmaySwEUGiIioba7oZ2VITkJDgoNEZHQ6o/G2Nt8LBQjwQcpNEREQmpfcyf9UReay20hgaFhZneYWZOZ7Rxl+0oze9bMes3sq4mqQ0QkVQ1OH7IiJCfBIbE9jbuAi8fY3gL8GfDPCaxBRCRl1TZ2kJFmLJ6Vn+xShiQsNJxzW/CCYbTtTc65F4D+RNUgIpLKahvaWVKeT2Z6eM4khKeSMZjZTWa21cy2Njc3J7scEZEpUdvYwaqQDOoblBKh4Zy73Tm30Tm3sby8PNnliIgkXFtXPw1tPaE6nwEpEhoiIjNNmO6hEU+hISISQoNzToXt8FR6ol7YzO4BLgBmmVkd8NdABoBz7jYzqwC2AoVAzMy+BKx2zrUnqiYRkVRR29hOSW4Gswuykl3KcRIWGs6568fZ3ghUJ+rni4ikst3+jZfMkn/jpXg6PCUiEjKxmOO1Qx2hmj5kkEJDRCRkOnoG6OqLUl2Sk+xSTqDQEBEJmbZub8xzWKZDj6fQEBEJmcHQKMxO2GnnU6bQEBEJmcHQKMrJSHIlJ1JoiIiEzFBo5Co0RERkHOppiIhIYAoNEREJrK27n4w0IycjLdmlnEChISISMm3d/RTlZIRuNDgoNEREQqe9u5/CEB6aAoWGiEjoDPY0wkihISISMgoNEREJTKEhIiKBKTRERCSQWMzR0aPQEBGRAI71DRBz4RzYBwoNEZFQaevyZ7hVaIiIyHjCPIUIKDREREKlXaEhIiJBqachIiKBKTRERCQwhYaIiATW1t1PesTIzQzftOig0BARCZUwT4sOCg0RkVAJ8xQioNAQEQmVthDfSwMUGiIiodKunoaIiAQ1Yw9PmdkdZtZkZjtH2W5m9l0z22NmO8xsQ6JqERFJFd7hqfRklzGqRPY07gIuHmP7JcAy/+sm4AcJrEVEJPScc7T3DMzMnoZzbgvQMsYuVwA/dp7ngGIzq0xUPSIiYXesd4BozM3M0AigCjgQt1znrzuBmd1kZlvNbGtzc/OUFCciMtXCPhockhsaI41ccSPt6Jy73Tm30Tm3sby8PMFliYgkh0JjbHXAvLjlaqA+SbWIiCTdYGhonMbINgGf8K+iOgdoc841JLEeEZGkCvu9NAASdl2Xmd0DXADMMrM64K+BDADn3G3Aw8AHgT1AF3BjomoREUkFqXB4KmGh4Zy7fpztDvh8on6+iEiqSYXQ0IhwEZGQaOvuJy1i5GfNzMF9IiJyEtq6+ynMTg/ttOig0BARCY227nCPBgeFhohIaIR9skJQaIiIhEbY76UBCg0RkdAI+700QKEhIhIaOjwlIiKBOOfU0xARkWC6+qIMhHxadFBoiIiEQiqMBgeFhohIKCg0REQkMIWGiIgElgr30gCFhohIKKinISIigQ3dgClXoSEiIuNo6+4nYpCfGd5p0UGhISISCoPzTkUi4Z0WHRQaIiKh4N1LI9yHpkChISISCqkw7xQoNEREQkGhISIigSk0REQksPYUuAETKDRERJLOOaeehoiIBNPdH6U/Gv5p0UGhISKSdEe7vNHgxSEfDQ4KDRGRpGvp7AOgJDczyZWMT6EhIpJkrV1eaJTlKzRERGQc6mmIiEhgrX5olOYpNEREZBwtXf2Yhf9eGpDg0DCzi83sVTPbY2ZfH2F7iZndb2Y7zOx5MzstkfWIiIRRa2cfxTkZpIV8hlsIGBpmlmdmEf/xcjO73MzGjEQzSwNuBS4BVgPXm9nqYbv9d2Cbc24t8AngOyfbABGRVNfS1UdJChyaguA9jS1AtplVAY8DNwJ3jfOcs4E9zrl9zrk+4F7gimH7rPZfD+dcLbDQzOYErElEZFpo7eyjNAVOgkPw0DDnXBdwNfBvzrmr8P7gj6UKOBC3XOevi7fdf03M7GxgAVB9wg83u8nMtprZ1ubm5oAli4ikhpbO6dfTMDM7F/go8JC/brx7Eo50cM4NW/4GUGJm24A/BV4GBk54knO3O+c2Ouc2lpeXByxZRCQ1tHalTk8j6M1ovwT8JXC/c67GzBYDT4zznDpgXtxyNVAfv4Nzrh3vUBdmZsAb/peIyIzgnKO1sz9lehqBQsM59zvgdwD+CfHDzrk/G+dpLwDLzGwRcBC4DrghfgczKwa6/HMefwJs8YNERGRGONY7QF80Rmle+C+3heBXT/3EzArNLA/YBbxqZl8b6znOuQHgC8BmYDfwM7+XcrOZ3ezvtgqoMbNavKusvniqDRERSUWtnd5khakwGhyCH55a7ZxrN7OPAg8DfwG8CHxzrCc55x72949fd1vc42eBZSdVsYjINNLSlTqjwSH4ifAMf1zGlcCvnHP9nHhSW0RETlIqTSECwUPjh8B+IA/YYmYLAJ17EBGZoJYUC42gJ8K/C3w3btWbZvbexJQkIjJzDE6LnipXTwU9EV5kZt8aHGBnZv+C1+sQEZEJaOnsIz1iFGQFPcWcXEEPT90BdAAf8b/agTsTVZSIyEzR6s875Q1VC7+g0bbEOXdN3PLf+qO4RURkAlpSaN4pCN7T6Dazdw8umNm7gO7ElCQiMnN4o8FTY2AfBO9p3Az82MyK/OVW4JOJKUlEZOZo6epj+Zz8ZJcRWNCrp7YD68ys0F9uN7MvATsSWJuIyLTX2tmXMqPB4STv3Oeca4+bG+rPE1CPiMiMEYs5b4bbFLncFiZ2u9fUONUvIhJSbd39xFzqzDsFEwsNTSMiIjIBqTbvFIxzTsPMOhg5HAzISUhFIiIzxOC8U6kyGhzGCQ3nXMFUFSIiMtMMzjtVlkKhMZHDUyIiMgGpNu8UKDRERJKmxb8B03QcES4iIpOstauP7IwIOZlpyS4lMIWGiEiSpNq8U6DQEBFJmtbOvpQ6nwEKDRGRpGlJsdHgoNAQEUmaVJt3ChQaIiJJ09KpnoaIiATQH43R3jOgnoaIiIyvdWjeqdS5ARMoNEREkqLVH9inq6dERGRcg/NOaZyGiIiMa+jwVL5CQ0RExqGehoiIBDZ4L41ihcbbzOxiM3vVzPaY2ddH2F5kZg+Y2XYzqzGzGxNZj4hIWLR09VGQlU5memp9dk9YtWaWBtwKXAKsBq43s9XDdvs8sMs5tw64APgXM0ut2BUROQWpOO8UJLancTawxzm3zznXB9wLXDFsHwcUmJkB+UALMJDAmkREQqGlq1+hMUwVcCBuuc5fF+97wCqgHngF+KJzLjb8hczsJjPbamZbm5ubE1WviMiUae3sozQ3tQb2QWJDw0ZY54YtfwDYBswFzgC+Z2aFJzzJududcxudcxvLy8snu04RkSnXosNTJ6gD5sUtV+P1KOLdCPzSefYAbwArE1iTiEgotHal3g2YILGh8QKwzMwW+Se3rwM2DdvnLeB9AGY2B1gB7EtgTSIiSdfTH6WrL5qSPY30RL2wc27AzL4AbAbSgDucczVmdrO//Tbg74C7zOwVvMNZf+GcO5yomkREwuDwsV6AlJsWHRIYGgDOuYeBh4etuy3ucT1wUSJrEBEJm2f2HgFgdeUJp3BDL7VGlYiITAMPbK9nfmkua6uLkl3KSVNoiIhMocPHenlm7xEuW1eJN0QttSg0RESm0COvNBCNOS5fN3zYWmpQaIiITKEHtjewfE4+KyoKkl3KKVFoiIhMkYa2bp7f38Jla+cmu5RTptAQEZkiD25vAOCydQoNEREZxwM76llbXcTCWXnJLuWUKTRERKbA/sOd7KhrS+lDU6DQEBFJiGjs+PlZH9zhTb33obWVyShn0iR0RHiYPPlqE//7wV3gvKl2Y87hHDh/4l3nv79pESM9YqRHIqRFjEgE0swwM9IiNrR9aL+0yHHLaZEIaRGGvmempZGdESE7w/uelZ5GepqRkRYhMy1CRlqErPQIWRnecnZGGjmZaeTEfc/NTEvJ67lFZqrtB45yzQ+eYdGsPM5ZXMa5S8r41bZ6zl5YytzinGSXNyEzJjQKczJYVVmIAREzzLzJrszsuDnco84xEHNEo46BWIyY8wImGnND3/ujMbr6vO3RGERjMQZijoGotz0ac0T9ffsGYvT0RxmIDZ8VPjgzyMtMJy8rjbysdAqyMyjISic/K538bO97/LbinAyKczMozsmkODeDkrxM8hQ8IlPmZ1sPkJ5mzC3O4Rcv1XH3c28C8HdXrElyZRM3Y0Jjw/wSNtxQkrSf3x/1wqM/6oVO30CM/miM/qijdyDqh0uM3oEo3f1Ruvui9PRH6eyL0tU7wLHeKJ29AxzrG+BYzwDHegdo6ugZetzZFz2hOxwvMy1CSV4GpXlZlBdkMbvg7e9zCrOZU5jF7IJsZhdmkZWeNoX/MiLTS99AjIdeaeCi1RV89/r19Edj7Khr47VDHVy1PjUH9MWbMaGRbBn+oahEcc7ROxCjvaeftq5+jnb3c7Srn9bOPlq7+mj1Hx/p7KW5o5fXD3XQ3NE7Yg+ovCCL6pIcqopzqC7JZV5pDvNKcplfmsvc4hwy03UqTGQ0T+9p5mhXP1ec4Z3wzkiLcOaCEs5ckLwPrZNJoTFNmJl/3iSN2QXZgZ4Tizlau/po6ujlUHsPTe29NLb3cLC1m7qjXew82Mbmmkb6o28HS8RgXmkuS8rzWTo7nyXleSydXcDyOfkUZKferStFJtumbfUU52Zw3rLpeZdRhcYMFokYZflZlOVnsWqUKZqjMceh9h4OtHRxoLWbt450svdwJ3ubjvH0nsP0Dbx9S/e5Rdksm1PAysoCVlcWsrqykEWz8khPYA9LJEy6+gZ4bNchrjijatr2yBUaMqa0iHcyb25xDu8Yti0ac9S1dvH6oWO8eqiD1w918OqhYzyz9/BQ7yQrPcLKigJWzy1izdxCVs/1wiQ7Q+dNZPr5ze4muvqiQ4empiOFhpyytIixoCyPBWV5/NHqOUPr+wZi7G0+xu6Gdmrq29lV385DO+q55/m3AEiPGCsqClg3r5gzqotZN6+YpbPzSYvo6i5JbZu2HaSiMJuzF5Ymu5SEUWjIpMtMj7CqspBVlYVcvcFb55yjrrWbmvp2Xjl4lO0H2nhgez0/+YMXJHmZaZxeXXRckFQWZesyYUkZrZ19PPlqM59+9yIi0/gDkEJDpoSZMa80l3mluVx8WgXgnYjfd7iTHXVH2X7gKNvq2rjz6f30Rb3zJLPyszhjXhFrq4tZW13E6VVFlOVnJbMZIqN6ZGcjAzHH5Sk8GWEQCg1JmkjEWDrbuwrr6g3VAPQORKlt6GBH3VG2HWhj24FWHq9tGhqxX1Wcw+lVb58fWTO3iDmFWeqRSNJt2n6QJeV5rJmbevf9PhkKDQmVrPQ01s3zDk99/FxvXUdPPzsP+oe16tqoOdjGozWNQ88pzctkZUUBKysKWVlZwIo5BSyclUdRji4BlqnR0NbNH95o4ct/tHzaf4BRaEjoFWRncO4Sb/6eQR09/dQ2dlBzsI3dDR3UNrbzk+ffpKf/7UuAS/MyWTQrjwWluVT5gxUHv88tztEVXDJpHnmlEedS+z4ZQSk0JCUVZGdw1sJSzoq7SiUac7x5pJPXm46x/3An+4908sbhTp7bd4TG9h6GD36flZ/pXU5clONfVpxNdUnO0CXGZXmZ0/5To0yOzTWNrJhTwKIUvk9GUAoNmTbSIsbi8nwWl+efsK0/GqOxrYeDR7s52NpNQ1s3B4/2UH+0mz3Nx9jyejNdfdHjnpOVHhnqnQyfTmVBWS7FuZlT1TQJsZbOPl7Y38Ln37s02aVMCYWGzAgZaZGhq7dG4pyjrbt/KFTqj3Zz8Gg39Ud7qGvtYnN9Iy2dfcc9pygngwVluSwoy2OpP63Ksjn5LCzLm7ajgeVEv9l9iJiDD6ypSHYpU0KhIYJ3SXBxbibFuZmsmVs04j6dvQMcaO3irSNdvNXSxZtHuth/pJOX32rlwR31x92TZdGsPFZUeCflV1R406pUl+TocNc09FhNI1XFOdP+qqlBCg2RgPKy0r0rtCpO/OPQ3Rdlb/Mx9jQd4/WmDl5tPMaOuqM8tKNhaJ+CrHRWVb49lcrquYUsn1OgXkkK6+wdYMvrh7nh7Pkz5gOBQkNkEuRkpnFaVRGnVR3fS+nsHeC1Qx3sbuhgV0Mbu+rb+dnWA0PnTzLSjGWzC1gzt9Afe1LEqsoCzRicIra81kzfQGzGHJoChYZIQuVlpbN+fgnr5799L4VozLH/SCe76r25uWrq2/htbRM/f7FuaJ8FZbmsqij0p2MpYE1VEXM1rUrobK5ppCQ3g7MWTo97ZQSR0NAws4uB7wBpwP9zzn1j2PavAR+Nq2UVUO6ca0lkXSLJlBYxlpTns6Q8f+i6fuccTR29fpC0sauhnd0NHWze1Th0rqQkN4PTqopYM7eItdVFbJhfQkVRsHunyOTrj8Z4vLaJi9dUzKjp/xMWGmaWBtwKvB+oA14ws03OuV2D+zjnvgl809//MuDLCgyZiczMv+1uNu9dOXtofVffgDeIsb6dmoNt7Kxv4z+e3jc09XxlUTYb5nt3hTtncRkrKwqm9WR5YfLcviN09Axw0Qw6NAWJ7WmcDexxzu0DMLN7gSuAXaPsfz1wTwLrEUk5uZnp3v3t4w5v9Q5E2d3QwUtvtvLSW628/NZRHnrFO+FenJvB2QtLOWdxGe9cWsby2QqRRNlc00huZhrnLZuV7FKmVCJDowo4ELdcByfcxwcAM8sFLga+kMB6RKaFrPQ0zphXzBnzivk0iwA4eLSbP+w7wnP7jvDsviM8tusQAGV5mZyzpIx3LinjvKXlzC8beZyKnJxYzPFYzSHOX14+46ajSWRojPTxxo2wDuAy4PejHZoys5uAmwDmz58/OdWJTCNVxTlcvaF6aLbgutYunt17hGf3HuGZvUeGLv2dX5rLu5fN4j3LZvHuZeXkZ+lamFOxre4oTR29XLRmzvg7TzOJ/I2pA+bFLVcD9aPsex1jHJpyzt0O3A6wcePG0YJHRHzVJbl8eGMuH944D+e8+5Y8/fphnnr9MJu2eTe/ykgzzlpYyoUrZ3PhytkjTr8iI/uvlw+SmR7hwpUzLzTMucT8DTazdOA14H3AQeAF4AbnXM2w/YqAN4B5zrnO8V5348aNbuvWrQmoWGRm6I/GeOnNVn77ahNP1Dbx2qFjACydnc9Fq+dw0ZoK1lYV6VzIKHr6o5z9f3/De1fO5jvXrU92OYGZ2YvOuY0TfZ2E9TSccwNm9gVgM94lt3c452rM7GZ/+23+rlcBjwUJDBGZuIy0CO9YXMY7Fpfxl5es4kBLF4/vPsRjuw7xwy37+P6Te6kozOaDp1dy6bpK1s8r1viQOI/ubKS9Z4A/3jhv/J2noYT1NBJFPQ2RxDna1cdva5t4ZGcjv3u1mb5ojKriHC5dW8mV66tYVTkz5lcay/W3P0fd0S5+99X3plRvLPQ9DRFJPcW5mUMn1Nt7+vl1zSEe3FHPfzz9Bj/cso+VFQVcvaGKK86oYk7hzBtY+OaRTp7dd4SvXrQ8pQJjMik0RGREhdkZXHNmNdecWU1LZx8P7qjnly8d5O8fruUbj9RywYrZXHfWPC5cOXvGjIj+2dYDRAyuPXNmHpoChYaIBFCal8knzl3IJ85dyL7mY/zipTp+vrWOm2qbmF2QxYc3VnPDOxZQVZyT7FITZiAa474X67hgxewZPX3LzPh4ICKTZnF5Pl/7wEqe+fqF/PsnNnJ6VRE/eHIv5/3jb7n57hd5bt8RUu1caRC/e62ZQ+29fGSGngAfpJ6GiJyS9LQI7189h/evnkNdaxf/+dxb3PvCWzxa08jKigI+e95iLj9jLhnT5NDVT184wKz8TN63avb4O09j0+PdFJGkqi7J5euXrOS5v3wf/3TNWgC+8vPtXPDNJ7nz92/Q1TeQ5Aonpqmjh8drm7hmQ/W0CcFTNbNbLyKTKjsjjY+cNY9Hvnged37qLKqKc/jbB3bxrm/8lu8/uYfO3mHhsXcv3HILFBZCJOJ9v+UWb32I3P/SQaIxx4dn+KEp0DgNEUmwrftbuPWJPTzxajOleZncfP5iPn7OQnIefwyuvRb6+72vQRkZ3td998EllySv8DgX/+sWsjPS+K/PvyvZpZyyyRqnoZ6GiCTUxoWl3Hnj2dx/yzs5raqIv3+4luu+/p/0X30NdHUdHxjgLXd1eYESgh7Hrvp2ahs7uHpDVbJLCQWFhohMifXzS/jxp8/mF587l8+/tAnX1zf2E/r74dvfnprixnD/y3WkR4xL185NdimhoNAQkSl15oJS3v/Sr8mMRcfesb8f7r57aooaRTTm+NW2ei5YMZvSvMyk1hIWCg0RmXJ27FiwHYPulyC/33OYpo5eHZqKo9AQkamXH/DeHUH3S5D/evkgBdnpXLhyZo/NiKfQEJGp97GPeVdIjaEvksaj69/P5ppGorGpv8qzq2+AR2sa+dDplTPulq5j0YhwEZl6X/kK/OhHJ145FccyM7njrCt5/u4XWVCWy+Xr5nLmghI2LCihMHvswJkMm2sa6eqLctV6HZqKp9AQkam3ZIk3DmOMcRoZ993HTy76AJtrDnHXM29w6xN7iDkwgxVzCrhgxWyuPbOKpbMLElLiL186SFVxDmctLE3I66cqhYaIJMcll8COHd5ltXff7Z30zs+Hj38cvvxlWLKEdOBDayv50NpKOnsH2HbgKFv3t/L8/iP8+1P7uO13e1k3r5hrN1Rx6dq5lEzSFU5N7T38fs9hbrlg6Yy9b8ZoNCJcRFJSc0cvv9p2kPterKO2sYO0iHHu4jIuPq2Ci9bMYXbBqU9f/o+P1vKDJ/fy+FfOZ0l5ck/GT5bJGhGu0BCRlOaco6a+nYdeaeDRnY28cbgTM9i4oISLVlfw/tVzWDgrL/DrPVHbxKd/9AJXra/iWx85I3GFTzGFhojIMM45Xj3UwSOvNPLYrkPsbmgHvHMgHzitgqvWV7FojAB580gnl/3b01SX5PKLz72TnMzpc9WUQkNEZBwHWrr49a5DbK5p5Pn9LTgHZ8wr5uoNVXzo9ErK8rOG9u3ui3LV939PQ1sPD/7pu5lXmpvEyiefQkNE5CQ0tvWwaftB7n+5nt0N7aRFjHcsKvXOgayu4B8e2c2m7fXcdePZnL+8PNnlTjqFhojIKdrd0M7DrzTwyM5G9jS9PVXJVy9azhcuXJbEyhJnskJDl9yKyIyzqrKQVZWFfOWiFexp6uDRnY30RR23XLA02aWFnkJDRGa0pbML+MKFiRkgOB1p7ikREQlMoSEiIoEpNEREJDCFhoiIBKbQEBGRwBQaIiISmEJDREQCU2iIiEhgKTeNiJk1A28OW10EtI2zbqzlwcfx62YBh0+xzJHqCbrPybZlvMcTacdYdQbZHqa2TOQ9GWnbTPn9Gr48vC2J/v0aa5/p/Ps10rqJtmWBc27ik2o551L+C7h9vHVjLQ8+HrZu62TWE3Sfk23LeI8n0o4gbRlre5jaMpH35GR/n6bT79d4bUn079dktiWVfr+S2ZbxvqbL4akHAqwba/mBUfaZzHqC7nOybQnyeCLGe52xtoepLRN5T0baNlN+v4Yvp3JbUun3a6R1U/n/flQpd3hqqpjZVjcJM0Im23RpB6gtYTRd2gFqS1DTpaeRCLcnu4BJMl3aAWpLGE2XdoDaEoh6GiIiEph6GiIiEphCQ0REApv2oWFmd5hZk5ntPIXnnmlmr5jZHjP7rplZ3LaPmNkuM6sxs59MbtWj1jPpbTGzT5lZs5lt87/+ZPIrH7GehLwv/vZrzcyZWcJPaiboPbnZX7/NzJ42s9WTX/mI9SSiLX/u/z/ZYWaPm9mCya98xHoS0Zb3mNlLZjZgZtdOftXH1XDK9Y/yep80s9f9r0/GrV9kZn/w1//UzDLHfbFEXcsbli/gPcAGYOcpPPd54FzAgEeAS/z1y4CXgRJ/eXYKt+VTwPemw/vibysAtgDPARtTsR1AYdw+lwOPpup7ArwXyPUffw74aQq3ZSGwFvgxcG0Y6weeBBYOW1cK7PO/l/iPB/92/Qy4zn98G/C58X7GtO9pOOe2AC3x68xsiZk9amYvmtlTZrZy+PPMrBLvP++zzvsX/TFwpb/5s8CtzrlW/2c0JbQRvgS1JSkS2Ja/A/4J6Elc9W9LRDucc+1xu+YBU3K1SoLa8oRzrsvf9TmgOqGN8CWoLfudczuAWFjrH8UHgF8751r8v1m/Bi72e1AXAvf5+/2IAH8Xpn1ojOJ24E+dc2cCXwW+P8I+VUBd3HKdvw5gObDczH5vZs+Z2cUJrXZsE20LwDX+4YP7zGxe4kod14TaYmbrgXnOuQcTXeg4JvyemNnnzWwvXgD+WQJrHc9k/H4N+gzeJ/dkmcy2JEOQ+kdSBRyIWx5sUxlw1Dk3MGz9mNIDlztNmFk+8E7g53GHwrNG2nWEdYOf+NLxDlFdgPfJ6SkzO805d3RSix3HJLXlAeAe51yvmd2M92njwsmudTwTbYuZRYBv4x1uS5pJek9wzt0K3GpmNwB/BXxyhP0TarLa4r/Wx4CNwPmTWWNQk9mWZBirfjO7Efiiv24p8LCZ9QFvOOeuYvQ2nVJbZ1xo4PWujjrnzohfaWZpwIv+4ibgBxzfla4G6v3HdcBzzrl+4A0zexUvRF5IYN0jmXBbnHNH4tb/O/CPiSp2HBNtSwFwGvCk/5+qAthkZpc757YmtvTjTMbvV7x7/X2TYVLaYmZ/BPwP4HznXG8iCx7DZL8vU23E+gGcc3cCdwKY2ZPAp5xz++N2qcP7gDuoGu/cx2Gg2MzS/d5GsLYm8mROWL7wTmDtjFt+Bviw/9iAdaM87wXgHN4+IfZBf/3FwI/8x7Pwun5lKdqWyrh9rsILw5R8X4bt8yRTcCI8Qe/Jsrh9LiOBk89NQVvWA3vj25SqbYnbfhcJPhF+qvUz+onwN/BOgpf4j0v9bT/n+BPht4xb11S/kUn4xbkHaAD68RL3M8Ai4FFgO7AL+F+jPHcjsNP/pf8eb4+gN+Bb/nNfGfxHT9G2/ANQ4z//CWBlqrZl2D5PMjVXTyXiPfmO/55s89+TNan6ngC/AQ75bdkGbErhtpzlv1YncASoCVv9jBAa/vpPA3v8rxvj1i/Gu1psD16AZI1Xm6YRERGRwGbq1VMiInIKFBoiIhKYQkNERAJTaIiISGAKDRERCUyhIdOCmR2b4p/3zCS9zgVm1mZmL5tZrZn9c4DnXGlTNPOtyHAKDZERmNmYsyU45945iT/uKefceryBcJea2bvG2f9KQKEhSTETpxGRGcLMlgC3AuVAF/BZ51ytmV2GN59TJt4grY865w6Z2d8Ac/FG4h42s9eA+XgDoOYD/+qc+67/2secc/lmdgHwN3hTMpyGNyXFx5xzzsw+iDcI9DDwErDYOXfpaPU657rNbBtvT8D4WeAmv849wMeBM/CmSz/fzP4KuMZ/+gntPNV/N5GxqKch09los4I+DZzjf7q/F/hvcc85E7jCOXeDv7wSb2rps4G/NrOMEX7OeuBLeJ/+FwPvMrNs4Id492J4N94f9DGZWQneHGZb/FW/dM6d5ZxbB+wGPuOcewZvjqSvOefOcM7tHaOdIpNOPQ2ZlsaZ1bQa+Kl/74RMvLl4Bm1yznXHLT/kvEn2es2sCZjD8VNnAzzvnKvzf+42vJ7KMWCfc27wte/B6zWM5Dwz2wGsAL7hnGv0159mZv8HKAbygc0n2U6RSafQkOlq1FlBgX8DvuWc2xR3eGlQ57B942dljTLy/5mR9hlp2unRPOWcu9TMlgNPm9n9zrlteBPjXemc225mn+L4mUoHjdVOkUmnw1MyLTnv7ndvmNmHAcyzzt9cBBz0HyfqPhW1wGIzW+gv//F4T3DOvYY3geRf+KsKgAb/kNhH43bt8LeN106RSafQkOki18zq4r7+HO8P7WfMbDverLFX+Pv+Dd7hnKfwTlJPOv8Q1y3Ao2b2NN5Mr20Bnnob8B4zWwT8T+APeLfnjD+xfS/wNf8y3SWM3k6RSadZbkUSxMzynXPH/Hsx3wq87pz7drLrEpkI9TREEuez/onxGrxDYj9MbjkiE6eehoiIBKaehoiIBKbQEBGRwBQaIiISmEJDREQCU2iIiEhg/x8QxaKUz7syzQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot(show_moms=True, suggestion=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 2.75e-2\n",
    "epochs = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.419287</td>\n",
       "      <td>0.441445</td>\n",
       "      <td>0.815932</td>\n",
       "      <td>02:20</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "start = time()\n",
    "learn.fit_one_cycle(epochs,learning_rate)\n",
    "t_final = time() - start "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140.42356204986572"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "rapids",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
